import jax
from jax import random
from jax import numpy as jnp
from PyCmpltrtok.common import sep

# Fixed seed so the random vector drawn later in this script is reproducible.
key = jax.random.PRNGKey(0)


def f(x):
    """Return half of ``dot(x.T, x)``.

    For a 1-D vector ``x`` the transpose is a no-op, so this is half the
    squared Euclidean norm, a scalar whose gradient with respect to ``x``
    is ``x`` itself. For a 2-D ``x`` the result is the matrix ``x.T @ x``
    scaled by one half.
    """
    product = jnp.dot(x.T, x)
    return product / 2.0


# --- Demo 1: evaluate f and its gradient at the all-ones vector ---
# `sep` comes from PyCmpltrtok.common; presumably it prints a labelled
# separator line (its implementation is not visible in this file).
sep('ones')

# Compute everything first, then report in the same order as before.
v = jnp.ones((4,))
fv = f(v)
xgrad = jax.grad(f)(v)  # jax.grad(f) builds the gradient function of f

print('v', v)
print('fv', fv)
print('xgrad', xgrad)

# --- Demo 2: gradient of f at a random point ---
sep('normal value')

# Draw a 4-element standard-normal vector from the module-level PRNG key.
v = random.normal(key, (4,))
# For a 1-D input the gradient of f is the input itself, so the printed
# gradient should match the printed v.
xgrad = jax.grad(f)(v)

print('v', v)
print('Gradient of f taken at v')
print('xgrad', xgrad)
